import java.util.Arrays;
import java.util.PriorityQueue;
import java.util.Scanner;

/**
 * Reads {@code n}, {@code r} and then {@code n} job durations from standard input, and prints the
 * sum, over all jobs, of the time each job waits before one of the {@code r} slots picks it up.
 * Jobs are handed out shortest-first, and each goes to whichever slot becomes free earliest.
 */
public class C_1527 {
    public static void main(String[] args) {
        try (Scanner sc = new Scanner(System.in)) {
            int n = sc.nextInt();
            int r = sc.nextInt();
            int[] t = new int[n];
            for (int i = 0; i < n; i++) {
                t[i] = sc.nextInt();
            }
            System.out.println(totalWait(t, r));
        }
    }

    /**
     * Computes the total waiting time when jobs are processed shortest-first on parallel slots.
     *
     * <p>Each job is assigned, in ascending order of duration, to the slot that becomes free
     * earliest; its waiting time is the moment that slot frees up. Ties between slots freeing at the
     * same moment do not affect the result, because the multiset of future free times is the same
     * whichever tied slot takes the job.
     *
     * @param times job durations; not modified
     * @param slots number of parallel slots
     * @return sum of all jobs' start times, as a {@code long} so large totals do not wrap
     * @throws IllegalArgumentException if {@code slots} is negative, or zero while jobs exist
     */
    static long totalWait(int[] times, int slots) {
        if (slots < 0 || (slots == 0 && times.length > 0)) {
            throw new IllegalArgumentException(
                    "need at least one slot for " + times.length + " jobs, got " + slots);
        }
        int[] sorted = times.clone();
        Arrays.sort(sorted);

        // Slots beyond the number of jobs are never used, so don't allocate them.
        int lanes = Math.min(slots, sorted.length);
        if (lanes == 0) {
            return 0;
        }

        // Min-heap of the moments at which each slot next becomes free.
        PriorityQueue<Long> freeAt = new PriorityQueue<>(lanes);
        for (int i = 0; i < lanes; i++) {
            freeAt.add(0L);
        }

        long total = 0;
        for (int duration : sorted) {
            long start = freeAt.poll();
            total += start;
            freeAt.add(start + duration);
        }
        return total;
    }
}
